import logging
import threading
from http.server import HTTPServer
from ipaddress import IPv4Address
from typing import Callable, Optional, Type
from uuid import uuid4

from common import OperatingSystem
from common.types import Event, Lock, NetworkPort
from common.utils.code_utils import insecure_generate_random_string
from infection_monkey.network import TCPPortSelector
from infection_monkey.network.tools import get_interface_to_target
from infection_monkey.utils.threading import create_daemon_thread

from .agent_binary_request import (
    AgentBinaryDownloadReservation,
    AgentBinaryDownloadTicket,
    AgentBinaryTransform,
    ReservationID,
)
from .http_agent_binary_request_handler import AgentBinaryHTTPRequestHandler

logger = logging.getLogger(name=__name__)

# A zero-argument callable returning the request handler *class* (not an instance) that the
# HTTP server is constructed with
AgentBinaryHTTPHandlerFactory = Callable[[], Type[AgentBinaryHTTPRequestHandler]]


def use_agent_binary(agent_binary: bytes) -> bytes:
    """
    Identity transform: serve the Agent binary exactly as it is

    Used as the default `agent_binary_transform` of `HTTPAgentBinaryServer.register()`.

    :param agent_binary: The Agent binary
    :returns: The very same `agent_binary` object, unmodified
    """
    unchanged = agent_binary
    return unchanged


class HTTPAgentBinaryServer:
    """
    Serves Agent binaries over HTTP

    Allows clients to register for an Agent binary to be served. The server will serve the
    requested binary until it is deregistered or the server is stopped. A stopped server can be
    started again, either explicitly with `start()` or implicitly by `register()`.

    :param tcp_port_selector: The TCP port selector to use
    :param get_handler_class: A function that returns the HTTP handler class to use
    :param create_event: A function that the server will use to create events
    :param lock: A lock that guards the server's state and the handler's reservations
    :param poll_interval: The interval to poll for server shutdown, in seconds
    """

    def __init__(
        self,
        tcp_port_selector: TCPPortSelector,
        get_handler_class: AgentBinaryHTTPHandlerFactory,
        create_event: Callable[[], Event],
        lock: Lock,
        poll_interval: float = 0.5,
    ):
        self._tcp_port_selector = tcp_port_selector
        # The handler *class*; HTTPServer instantiates it once per incoming request
        self._handler_class = get_handler_class()
        self._create_event = create_event
        self._lock = lock
        self._poll_interval = poll_interval
        # The following are set while a server exists and cleared again by stop()
        self._port: Optional[NetworkPort] = None
        self._server: Optional[HTTPServer] = None
        self._server_thread: Optional[threading.Thread] = None

    def register(
        self,
        operating_system: OperatingSystem,
        requestor_ip: IPv4Address,
        agent_binary_transform: AgentBinaryTransform = use_agent_binary,
    ) -> AgentBinaryDownloadTicket:
        """
        Register to download an Agent binary

        If the server is not running, it will be started.

        :param operating_system: The operating system for the Agent binary to serve
        :param requestor_ip: The IP address of the client that will download the Agent binary
        :param agent_binary_transform: A callable that transforms the Agent binary before serving.
            This may be used to, e.g., convert the binary into a self-extracting shell script.
            Defaults to no-op
        :raises RuntimeError: If the binary could not be served
        :raises Exception: If the server failed to start
        :returns: A ticket to download the Agent binary
        """
        with self._lock:
            if not self.server_is_running():
                self._start_server()

            reservation_id = uuid4()
            url = self._build_request_url(reservation_id, operating_system, requestor_ip)
            reservation = AgentBinaryDownloadReservation(
                reservation_id,
                operating_system,
                agent_binary_transform,
                url,
                self._create_event(),
            )
            self._handler_class.reserve_download(reservation)

            return AgentBinaryDownloadTicket(reservation_id, url, reservation.download_completed)

    def _build_request_url(
        self,
        reservation_id: ReservationID,
        operating_system: OperatingSystem,
        requestor_ip: IPv4Address,
    ) -> str:
        """
        Build the download URL for a reservation

        The host part is the local interface address that `get_interface_to_target` returns for
        `requestor_ip`; the port is the port the current server listens on.
        """
        server_ip = get_interface_to_target(str(requestor_ip))
        return f"http://{server_ip}:{self._port}/{operating_system.value}/{reservation_id}"

    def server_is_running(self) -> bool:
        """Return True if the thread serving HTTP requests exists and is alive"""
        return self._server_thread is not None and self._server_thread.is_alive()

    def _start_server(self):
        """
        Create the HTTP server if needed and (re)start the thread that serves it

        Callers in this class hold `self._lock` while calling this method.

        :raises RuntimeError: If no free TCP port could be found
        """
        if self._server is None:
            self._server = self._create_server()
        # Also replace a thread that exists but is no longer alive (e.g. it failed to start);
        # otherwise the server would silently stay unserved.
        if self._server_thread is None or not self._server_thread.is_alive():
            self._server_thread = self._create_server_thread(self._server)
            self._server_thread.start()

    def _create_server(self) -> HTTPServer:
        """
        Select a free TCP port and create an HTTP server bound to it on all IPv4 interfaces

        :raises RuntimeError: If no free TCP port could be found
        """
        self._port = self._tcp_port_selector.get_free_tcp_port(
            # Allow 443, 80 in the future?
            preferred_ports=list(map(NetworkPort, [8080, 8008, 8000, 8443]))
        )
        if self._port is None:
            raise RuntimeError("Could not find a free TCP port to serve Agent binaries")

        return HTTPServer(("0.0.0.0", int(self._port)), self._handler_class)

    def _create_server_thread(self, server: HTTPServer) -> threading.Thread:
        """Create (but do not start) a daemon thread that runs `server.serve_forever`"""
        thread_name = f"HTTPAgentBinaryServer-{insecure_generate_random_string(n=8)}"
        return create_daemon_thread(
            target=server.serve_forever,
            name=thread_name,
            args=(self._poll_interval,),
        )

    def deregister(self, reservation_id: ReservationID) -> None:
        """
        Deregister an Agent binary from being served

        :param reservation_id: The ID of the reservation to deregister
        :raises KeyError: If the reservation ID is not registered
        """
        with self._lock:
            self._handler_class.clear_reservation(reservation_id)

    def start(self):
        """
        Start the server if it is not already running

        :raises Exception: If the server failed to start
        """
        # Same check-then-start as register(), so it must hold the same lock to avoid two
        # concurrent callers each creating a server.
        with self._lock:
            if not self.server_is_running():
                logger.debug("Starting the HTTP server")
                self._start_server()

    def stop(self, timeout: Optional[float] = None):
        """
        Stop the server

        The listening socket is closed and the server's state is cleared, so a later call to
        `start()` or `register()` starts a fresh server (on a newly selected port).

        Note: `HTTPServer.shutdown()` blocks until the serve loop exits (for example, until an
        in-flight request has been handled), so `timeout` only bounds the subsequent wait for
        the server thread to finish.

        :param timeout: The maximum amount of time to wait for the server thread to stop, in
            seconds. If not provided or set to None, it will block until the server shuts down
        """
        # Detach the current server under the lock so that a concurrent register()/start()
        # creates a new server instead of reusing the one being torn down. The (potentially
        # slow) shutdown below happens outside the lock.
        with self._lock:
            server, server_thread = self._server, self._server_thread
            self._server = None
            self._server_thread = None
            self._port = None

        if server is None:
            return

        if server_thread is not None and server_thread.is_alive():
            logger.debug("Stopping the HTTP server")
            server.shutdown()
            server_thread.join(timeout)

        # Either shutdown() has returned (so serve_forever() has exited) or serve_forever() was
        # never running; in both cases the listening socket is unused and can be released.
        server.server_close()

        if server_thread is not None and server_thread.is_alive():
            logger.warning("Timed out waiting for HTTP server to stop")
        else:
            logger.debug("The HTTP server has stopped")
